library(maptools)
library(tidyverse)
library(sp)
library(sf)
library(raster)
library(rgdal)
library(PNWColors)
library(MetBrewer)
library(viridis)



      #define some functions
      #########################
      # Return `col` as RGBA hex strings ("#RRGGBBAA") with every colour given
      # the same opacity `alpha` (0 = transparent, 1 = opaque). Any alpha
      # already present in `col` is discarded; only its red/green/blue are kept.
      add.alpha <- function(col, alpha=1){
        if (missing(col)) {
          stop("Please provide a vector of colours.")
        }
        # Convert one column of channel fractions into a hex colour string
        to_rgba <- function(channel) {
          rgb(channel[1], channel[2], channel[3], alpha = alpha)
        }
        # 3 x n matrix: red/green/blue rows, one column per colour, scaled to [0, 1]
        unit_rgb <- sapply(col, col2rgb) / 255
        apply(unit_rgb, 2, to_rgba)
      }
      
      # Build a raster with the same geometry as `raster` whose value in each
      # cell is that cell's own cell number (1..ncell). The input's values are
      # not used.
      get_grid_raster <- function(raster){
        cell_grid <- raster(raster)
        cell_grid[] <- seq_len(ncell(cell_grid))
        cell_grid
      }
      
      # Return the cell numbers of `raster`'s grid that are NOT covered by any
      # polygon in `poly` (callers then set those cells to NA to mask them).
      #   raster  a Raster* object; only its geometry is used, not its values
      #   poly    polygons accepted by raster::extract()
      # Returns an integer vector of cell numbers in increasing order.
      identify_non_cells <- function(raster, poly){
        # Template on the same grid whose value in every cell is its cell number
        cell_ids <- raster::raster(raster)
        n_cells <- raster::ncell(cell_ids)
        cell_ids[] <- seq_len(n_cells)
        # extract() returns, per polygon, the values (= cell numbers) of the cells
        # it selects. Namespaced on purpose: the tidyverse attaches
        # tidyr::extract, which would be dispatched to instead if raster were
        # loaded before tidyverse.
        covered <- unlist(raster::extract(cell_ids, poly))
        setdiff(seq_len(n_cells), covered)
      }
      #########################



    map <- read_rds("figdat/world_borders_WB.rds")

    # Split Africa into plotting regions by country name (ADM0_NAME).
    # West Africa: both spellings of Côte d'Ivoire are listed so either one matches.
    af_w <- local({
      west <- c(
        "Côte d'Ivoire", "Burkina Faso", "Benin", "Togo", "Ghana", "Ivory Coast",
        "Liberia", "Sierra Leone", "Guinea", "Guinea-Bissau", "Senegal",
        "Nigeria", "Niger", "Mali"
      )
      map[map$ADM0_NAME %in% west, ]
    })
    # Central Africa
    af_c <- local({
      central <- c(
        "Cameroon", "Central African Republic", "Chad", "Congo",
        "Democratic Republic of Congo", "Equatorial Guinea", "Gabon",
        "São Tomé and Príncipe"
      )
      map[map$ADM0_NAME %in% central, ]
    })
    # Study area: west + central Africa together
    af <- map[map$ADM0_NAME %in% c(af_w$ADM0_NAME, af_c$ADM0_NAME), ]
    # Study area plus the other African countries drawn as gray background
    af_other <- local({
      rest_of_africa <- c(
        "Algeria", "Morocco", "Western Sahara", "Mauritania", "Mali", "Libya",
        "Tunisia", "Arab Republic of Egypt", "Ethiopia", "Sudan", "South Sudan",
        "Kenya", "Uganda", "Tanzania", "Somalia", "Eritrea", "Angola", "Zambia",
        "Mozambique", "Malawi", "Zimbabwe"
      )
      map[map$ADM0_NAME %in% c(af_w$ADM0_NAME, af_c$ADM0_NAME, rest_of_africa), ]
    })





    # This (commented-out) section reads the raster data for PM2.5, wealth, and urbanicity, averages the PM2.5 grids cell by cell across the selected annual files, and saves the results.
    # The raw data are not included here, but all of it is publicly available online from the respective authors,
    # so you can download their data and run this code to regenerate the averaged values shown in the figure.
    
      ################### PM2.5 #########################

          ## All PM ##
        #       
        #       grid <- raster(raster("1_input/data/vandonk/annual_001/V5GL02.HybridPM25.Global.199801-199812.nc"))
        #       grid[] <- 1:ncell(grid)
        # 
        #       
        #       fls <- list.files("1_input/data/vandonk/annual_001/")[13:22][6:10]
        #       
        #       
        #       rdat <- raster::brick(x = grid, nl = length(fls))
        #       
        #       for (i in 1:length(fls)){
        #         
        #         rdat[[i]] <- (raster(paste("1_input/data/vandonk/annual_001/", fls[i], sep = "")))
        #         
        #       }
        #       
        #       
        #       r_af <- raster::brick(x = raster::crop(grid,extent(af)), nl = length(fls))
        # 
        # 
        #       
        #       for (i in 1:length(fls)){
        #         r_af[[i]] <- raster::crop(rdat[[i]], extent(af))
        #       }
        #       
        #       r_af <- raster::stackApply(x=r_af,indices = 1, fun= mean)
        #       
        #       
        #       
        #       
        #       r_af_grid <- get_grid_raster(r_af)
        #       noncell_af <- identify_non_cells(r_af_grid, af)
        #       r_af[noncell_af] <- NA
        #       
        #       pm <- r_af
        #       
        #       
        #   ## no dust no sea salt
        #       
        #       
        #       
        #       grid <- raster(raster("1_input/data/vandonk/annual_no_ss/ACAG_PM25_noDUSTnoSEASALT_GWR_V4GL03_201001_201012_0p01.nc"))
        #       grid <- flip(flip(t(grid), direction = "x"), direction = "y")
        #       grid[] <- 1:ncell(grid)
        #       
        #       
        #       fls <- list.files("1_input/data/vandonk/annual_no_ss/")[16:20]
        #       
        #       
        #       rdat <- raster::brick(x = grid, nl = length(fls))
        #       
        #       for (i in 1:length(fls)){
        #         
        #         rdat[[i]] <- flip(flip(t(raster(paste("1_input/data/vandonk/annual_no_ss/", fls[i], sep = ""))), direction = "x"), direction = "y") 
        #         
        #       }
        #       
        #       
        #       r_af <- raster::brick(x = raster::crop(grid,extent(af)), nl = length(fls))
        #       
        #       
        #       
        #       for (i in 1:length(fls)){
        #         r_af[[i]] <- raster::crop(rdat[[i]], extent(af))
        #       }
        #       
        #       r_af <- raster::stackApply(x=r_af,indices = 1, fun= mean)
        #       
        #       r_af_grid <- get_grid_raster(r_af)
        #       noncell_af <- identify_non_cells(r_af_grid, af)
        #       r_af[noncell_af] <- NA
        #       
        #       pmnd <- r_af
        #       
        #       
        #       
        #       
        # ############### Wealth ######
        #       
        #       #loop over countries and read in wealth data
        #       
        #       wfls <- list.files("1_input/data/rwi/")
        #       wfls_cty <- substr(wfls,1,3)
        #       
        #       wfls <- wfls[wfls_cty %in% af$ISO3166_1_]
        #       
        #       rwi_list <- list()
        # 
        #       for (i in 1:length(wfls)){
        #         rwi_list[[i]] <- read_csv(paste("1_input/data/rwi/",wfls[i],sep=""))
        #         rwi_list[[i]]$country <- wfls_cty[i]
        #       }      
        #       
        #       
        #       rwi <- data.frame(data.table::rbindlist(rwi_list))
        #       
        #       
        #     ######## Urbanicity
        #       
        #       ur <- raster("1_input/data/GIS/rasters/Urban-Rural Catchment Areas (URCA).tif")
        #       ur <- raster::crop(ur, extent(af))
        #       
        #       ur_grid <- get_grid_raster(ur)
        #       noncell_ur <- identify_non_cells(ur_grid, af)
        #       ur[noncell_ur] <- NA
        #       
        #       
        #       
        #       save(pm,pmnd,ur, rwi, file = "2_analysis/data/west-africa-plot-data.RData")
        #       
        #       
      
      
      ############ plot ####################

      # Load pm, pmnd, ur and rwi (the objects written by the save() call in the
      # commented-out preprocessing section).
      load("data/west-africa-plot-data.RData")

      # Colour palettes
      pal_pm <- colorRampPalette(colors = MetBrewer::met.brewer("VanGogh2", 8))(2000) %>% rev()
      pal_urca <- magma(2000)
      pal_wealthn <- colorRampPalette(colors = c("red3", MetBrewer::met.brewer("OKeeffe1", 16)[c(1:7)]))(2000)
      pal_wealthp <- colorRampPalette(colors = c(MetBrewer::met.brewer("OKeeffe1", 16)[c(10:16)], "navy"))(2000)

      # Fixed-break classes for negative and positive RWI values.
      # NOTE(review): intn, intp, pal_wealthn and pal_wealthp are not referenced
      # anywhere else in this file as shown -- confirm before removing them.
      intn <- classInt::classIntervals(rwi$rwi[rwi$rwi <= 0], style = "fixed", fixedBreaks = c(-100, seq(-1, 0, .1)))
      intp <- classInt::classIntervals(rwi$rwi[rwi$rwi > 0], style = "fixed", fixedBreaks = c(seq(0, 1, .1), 100))

      # Colour each RWI point by its quantile class *within its own country*, so
      # the colour shows a point's wealth relative to the rest of its country.
      rwi$color2 <- NA
      pal_wealth <- colorRampPalette(colors = c("red3", MetBrewer::met.brewer("Benedictus", 16)[c(1:6, 8:13)], "navy"))(2000)
      cty <- unique(rwi$country)
      # seq_along: an empty `cty` must run zero iterations (1:length() would run i = 1, 0)
      for (i in seq_along(cty)) {
        in_cty <- rwi$country == cty[i]
        int <- classInt::classIntervals(rwi$rwi[in_cty], style = "quantile")
        rwi$color2[in_cty] <- classInt::findColours(int, pal_wealth)
      }


          
      
            
      
    
      # Fig 3a-d: four map panels in one row of a single PDF page.
      #   a: total PM2.5 (pm)          b: PM2.5 without dust/sea salt (pmnd)
      #   c: urbanicity (ur)           d: RWI locations coloured by rwi$color2
      # Panels a and b share a single colour scale, fixed by pm_zlim.
      pm_zlim <- c(5, 75)

      # Shared background: sets the map window from the west-Africa polygons
      # (outlines hidden) and fills the surrounding countries in light gray.
      draw_map_base <- function() {
        plot(af_w, border = NA, xlim = c(-17, 16.5))
        plot(af_other, border = "white", lwd = 0.01, col = "gray92", add = TRUE)
      }

      # Shared foreground: semi-transparent white borders of the study countries.
      draw_map_borders <- function() {
        plot(af, border = add.alpha("white", .5), lwd = 0.01, col = NA, add = TRUE)
      }

      pdf(file = "figures/Fig3a-d-raw.pdf", width = 20, height = 5)

      par(mfrow = c(1, 4))
      par(mar = c(2, 2, 2, 2))

      # panel a: clamp into pm_zlim (top/bottom-code), and pass zlim explicitly
      # so the colour mapping is pm_zlim rather than this raster's own min/max.
      pm[pm > pm_zlim[2]] <- pm_zlim[2]
      pm[pm < pm_zlim[1]] <- pm_zlim[1]
      draw_map_base()
      plot(pm, col = pal_pm, zlim = pm_zlim, add = TRUE, legend = FALSE)
      draw_map_borders()

      # panel b: identical clamping and colour scale to panel a. Because zlim is
      # pinned, no sentinel value is needed to stretch the scale to 75, and
      # values between ~62 and 75 keep their graduated colours.
      pmnd[pmnd > pm_zlim[2]] <- pm_zlim[2]
      pmnd[pmnd < pm_zlim[1]] <- pm_zlim[1]
      draw_map_base()
      plot(pmnd, col = pal_pm, zlim = pm_zlim, add = TRUE, legend = FALSE)
      draw_map_borders()

      # panel c: urbanicity raster
      draw_map_base()
      plot(ur, add = TRUE, col = pal_urca, legend = FALSE)
      draw_map_borders()

      # panel d: one small filled square per RWI location
      draw_map_base()
      points(rwi$longitude, rwi$latitude, pch = 15, cex = 0.06, col = rwi$color2)
      draw_map_borders()

      dev.off()
      
      
      
      
      #############

      
      
      
      
      
      
      
      
      
      
      
      #### LEGENDS ####

      # Draw a horizontal colour bar for palette `pal` into the PDF `file`.
      # The bar runs from x = 0 (first colour) to x = 10 (last colour); `at`
      # gives label positions on that 0-10 scale. Tick marks under the bar are
      # drawn only when `ticks` is TRUE. The device is closed even on error.
      draw_colorbar_legend <- function(file, pal, at, labels, ticks = TRUE) {
        pdf(file, width = 8, height = 10)
        on.exit(dev.off(), add = TRUE)
        plot(1, 1, col = NA, axes = FALSE, xlab = "", ylab = "",
             xlim = c(-1, 11), ylim = c(-5, 5))
        plotrix::gradient.rect(0, -.25, 10, .25,
                               col = (plotrix::smoothColors(colorRampPalette(pal)(500))),
                               gradient = "x", border = "black", nslices = 100000)
        if (ticks) {
          segments(x0 = at, y0 = -.4, y1 = -.25, lwd = 1)
        }
        text(x = at, y = -.7, labels = labels, cex = 1.5)
        invisible(file)
      }

      # Fig 3a/b: the map panels clamp PM2.5 to 5-75, so the bar spans 5 (x = 0)
      # to 75 (x = 10) and a value v belongs at (v - 5) / 70 * 10, not at an even
      # split of the bar.
      pm_ticks <- seq(10, 70, 10)
      draw_colorbar_legend("figures/fig3-a-b-legend.pdf", pal_pm,
                           at = (pm_ticks - 5) / 70 * 10, labels = pm_ticks)

      # Fig 3c: labels evenly spaced along the bar.
      # NOTE(review): 1, 5, 10, ..., 30 are not evenly spaced values; if the URCA
      # colours span classes 1-30 linearly, positions should be
      # (v - 1) / 29 * 10 -- confirm against the data range of `ur`.
      draw_colorbar_legend("figures/fig3-c-legend.pdf", pal_urca,
                           at = seq(0, 10, 10 / 6), labels = c(1, seq(5, 30, 5)))

      # Fig 3d: wealth colours are within-country quantile classes, so only the
      # direction (lower / higher) is labelled and no ticks are drawn.
      draw_colorbar_legend("figures/fig3-d-legend.pdf", pal_wealth,
                           at = c(0, 5, 10), labels = c("lower", " ", "higher"),
                           ticks = FALSE)
      
      
      
      
      
########################################################################################################################
      
      #generate scatters for inserts between panels

      # Reload the unmodified objects: the map panels clamp pm/pmnd in place for
      # colouring, while the insets need the raw values.
      load("data/west-africa-plot-data.RData")

      # Raster values at each RWI location. Each raster gets its own
      # cellFromXY() lookup: pm and pmnd were built from different input files
      # (see the commented preprocessing code), so pm's cell numbers are not
      # assumed to address the same locations on pmnd's grid.
      rwi_xy <- as.matrix(rwi[, c("longitude", "latitude")])
      rwi$pmcell <- cellFromXY(pm, rwi_xy)
      rwi$pm_tot <- pm[rwi$pmcell]
      rwi$pm_ant <- pmnd[cellFromXY(pmnd, rwi_xy)]
      rwi$urcell <- cellFromXY(ur, rwi_xy)
      rwi$urca <- ur[rwi$urcell]

      # Write one inset PNG: a 2-D histogram of `x` (binned on x_breaks) against
      # `y` (binned on y_breaks), overlaid with the least-squares line of y on x
      # fitted to exactly the observations counted in the histogram.
      #   file            output PNG path
      #   x, y            numeric vectors of equal length; NA / out-of-range dropped
      #   y_at, y_labels  y-axis tick positions (data units) and their labels
      #   x_ticks         x-axis tick positions (data units), also used as labels
      #   y_down          if TRUE the first y bin is drawn at the top
      plot_rwi_inset <- function(file, x, y, y_breaks, y_at, y_labels,
                                 x_breaks = seq(-2, 2, 0.1),
                                 x_ticks = seq(-2, 2, 1), y_down = FALSE) {
        n_x <- length(x_breaks) - 1
        n_y <- length(y_breaks) - 1

        # Bin index per observation. Intervals are (lo, hi], the same rule as
        # `v > lo & v <= hi`; values outside the breaks (or NA) give NA.
        x_bin <- cut(x, x_breaks, right = TRUE, labels = FALSE)
        y_bin <- cut(y, y_breaks, right = TRUE, labels = FALSE)
        shown <- !is.na(x_bin) & !is.na(y_bin)

        # Counts: rows = y bins, columns = x bins. Every bin (including the last
        # row and column) is filled; empty bins become NA so they are not drawn.
        counts <- table(factor(y_bin[shown], levels = seq_len(n_y)),
                        factor(x_bin[shown], levels = seq_len(n_x)))
        mat <- matrix(as.vector(counts), nrow = n_y, ncol = n_x)
        mat[mat == 0] <- NA
        # raster() draws matrix row 1 at the top; reverse the rows so that y
        # increases upwards unless a downward axis was requested.
        if (!y_down) {
          mat <- mat[rev(seq_len(n_y)), , drop = FALSE]
        }

        # raster(mat) occupies [0, 1] on both axes; map data values onto it.
        x_lo <- x_breaks[1]
        x_hi <- x_breaks[n_x + 1]
        to_x <- function(v) (v - x_lo) / (x_hi - x_lo)
        to_y <- function(v) {
          p <- (v - y_breaks[1]) / (y_breaks[n_y + 1] - y_breaks[1])
          if (y_down) 1 - p else p
        }

        # Fit on the unrounded values of the plotted observations
        fit <- lm(y ~ x, data = data.frame(x = x[shown], y = y[shown]))
        x_line <- seq(x_lo, x_hi, length.out = 101)
        y_line <- predict(fit, newdata = data.frame(x = x_line))

        png(file = file, width = 400, height = 300)
        on.exit(dev.off(), add = TRUE)
        par(mar = c(4, 3, 1, 0))
        par(mgp = c(3, 1.75, 0))
        plot(raster(mat), col = pnw_palette("Winter", 100)[100:1],
             legend = FALSE, box = FALSE, axes = FALSE)
        axis(1, tick = TRUE, at = to_x(x_ticks), labels = x_ticks,
             cex.axis = 2, line = 1)
        axis(2, tick = TRUE, at = to_y(y_at), labels = y_labels, las = 2,
             cex.axis = 2, line = -2)
        lines(to_x(x_line), to_y(y_line), col = "navy", lwd = 5)
        invisible(file)
      }

      #### PANEL A: wealth vs total PM2.5 #####
      plot_rwi_inset("figures/fig3-scatter-a-withaxis.png",
                     x = rwi$rwi, y = rwi$pm_tot, y_breaks = 0:70,
                     y_at = seq(0, 70, 10), y_labels = seq(0, 70, 10))

      #### PANEL B: wealth vs PM2.5 without dust and sea salt #####
      plot_rwi_inset("figures/fig3-scatter-b-withaxis.png",
                     x = rwi$rwi, y = rwi$pm_ant, y_breaks = 0:70,
                     y_at = seq(0, 70, 10), y_labels = seq(0, 70, 10))

      #### PANEL C: wealth vs urbanicity (URCA) #####
      # Drawn with URCA value 1 at the top, as in the original unflipped layout.
      # The y axis is labelled in URCA units (same values as the fig3-c legend)
      # instead of the PM2.5 0-70 labels.
      urca_ticks <- c(1, seq(5, 30, 5))
      plot_rwi_inset("figures/fig3-scatter-c-withaxis.png",
                     x = rwi$rwi, y = rwi$urca, y_breaks = 0:30,
                     y_at = urca_ticks, y_labels = urca_ticks, y_down = TRUE)
      
      
      
     
      
      
